import torch as tch
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

""" Custom nmRNN implementation using JIT """

import torch
from torch import nn, jit
import math


class RNNCell_LowRank_base(jit.ScriptModule):
    """Parameter container for a low-rank recurrent cell with optional neuromodulators (NMs).

    The class only allocates and initialises parameters; subclasses provide
    ``forward``.

    Parameters created:
        pages        -- (2, N_NM, rank1, hidden_size); only when ``N_NM > 0``.
        page0        -- (2, rank0, hidden_size).
        weight_ih    -- (hidden_size, input_size).
        weight_h2nm  -- (N_NM, hidden_size); only when ``N_NM > 0``.
        weight_nm2nm -- (N_NM, N_NM); only when ``N_NM > 0``.
        bias         -- (hidden_size,) when ``bias`` is truthy, otherwise
                        registered as ``None``.

    ``nonlinearity`` and ``keepW0`` are stored as given and are not used by
    this class itself.
    """

    def __init__(self, N_NM, input_size, hidden_size,  rank0, rank1, nonlinearity, bias, keepW0 = False):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.nonlinearity = nonlinearity
        self.N_nm = N_NM
        self.keepW0 = keepW0
        self.g = 10  # gain used to derive the slope handed to kaiming_uniform_ for the pages

        has_nm = self.N_nm > 0

        # Low-rank factors. Registration order matters for parameters()/state_dict().
        if has_nm:
            self.pages = nn.Parameter(torch.empty(2, N_NM, rank1, hidden_size))
        self.page0 = nn.Parameter(torch.empty(2, rank0, hidden_size))

        # Input projection.
        self.weight_ih = nn.Parameter(torch.empty(hidden_size, input_size))

        # Hidden -> NM and NM -> NM couplings.
        if has_nm:
            self.weight_h2nm = nn.Parameter(torch.empty(N_NM, hidden_size))
            self.weight_nm2nm = nn.Parameter(torch.empty(N_NM, N_NM))

        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(torch.empty(hidden_size))

        self.reset_parameters()

    def reset_parameters(self):
        """(Re)initialise every parameter in place.

        The order of the calls below fixes how the random stream is consumed,
        so it must not be changed if seeded runs are to stay reproducible.
        """
        nn.init.kaiming_uniform_(self.weight_ih, a=math.sqrt(5))

        page_slope = self.g / math.sqrt(self.hidden_size)
        nn.init.kaiming_uniform_(self.page0, a=page_slope)
        if self.N_nm > 0:
            nn.init.sparse_(self.weight_h2nm, 0.1)
            nn.init.zeros_(self.weight_nm2nm)
            nn.init.kaiming_uniform_(self.pages, a=page_slope)

        if self.bias is None:
            return
        # Bias is drawn uniformly in +-1/sqrt(fan_in) of the input weights.
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight_ih)
        limit = 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.bias, -limit, limit)


            
class RNNCellLowRank(RNNCell_LowRank_base):  # Euler integration of rate-neuron network dynamics
    """Single-step low-rank rate cell with optional neuromodulators (NMs).

    The combined state ``hiddenCombined`` is indexed with three axes and holds
    the hidden units followed by the ``N_nm`` neuromodulator units on the last
    axis. One call performs one leaky update::

        state <- decay * state + (1 - decay) * nonlinearity(drive)

    Args:
        N_nm: number of neuromodulator units (0 disables neuromodulation).
        input_size, hidden_size, rank0, rank1, bias: forwarded to the base class.
        nonlinearity: callable applied to the summed drive.
        decay: leak factor of the update above.
        keepW0: forwarded to the base class, which stores it.
    """

    def __init__(self, N_nm, input_size, hidden_size, rank0, rank1, nonlinearity = None, decay = 0, bias = True, keepW0 = True):
        # keepW0 used to be silently dropped here; forward it so the stored flag
        # matches what the caller asked for.
        super().__init__(N_nm, input_size, hidden_size, rank0, rank1, nonlinearity, bias, keepW0)
        self.decay = decay    #  torch.exp( - dt/tau )
        self.N_nm = N_nm

    def forward(self, input, hiddenCombined):
        """Advance the combined state by one step.

        Args:
            input: input for this step; it is right-multiplied by ``weight_ih.t()``.
            hiddenCombined: rank-3 state whose last axis is ``[hidden | nm]``.

        Returns:
            The updated state in the same ``[hidden | nm]`` layout
            (only the hidden part when ``N_nm == 0``).
        """
        # Unmodulated recurrent matrix rebuilt from its low-rank factors.
        self.weight0_hh = tch.einsum('ri,rj->ij', self.page0[0], self.page0[1])

        # Split the combined state into hidden units and neuromodulators.
        if self.N_nm > 0:
            # One (hidden x hidden) page per NM, built directly instead of
            # forming every cross-NM outer product and keeping the diagonal.
            self.weight_hh = tch.einsum('nri,nrj->ijn', self.pages[0], self.pages[1])
            hidden = hiddenCombined[:, :, :-self.N_nm]
            nm = hiddenCombined[:, :, -self.N_nm:]
        else:
            hidden = hiddenCombined
            nm = None

        drive = input @ self.weight_ih.t()
        if nm is not None:
            # hidden and nm are rank 3, so the same subscripts are needed with and
            # without a bias (the bias branch previously used rank-2 subscripts
            # and raised). The leading axis is summed out.
            drive = drive + torch.einsum('tbj,ijk,tbk->bi', hidden, self.weight_hh, nm)
        drive = drive + hidden @ self.weight0_hh.t()
        if self.bias is not None:
            drive = drive + self.bias
        activity = self.nonlinearity(drive)

        if nm is not None:
            # NMs are driven by the hidden state *before* this step's update.
            activity_nm = self.nonlinearity(hidden @ self.weight_h2nm.t() + nm @ self.weight_nm2nm.t())
            nm = self.decay * nm + (1 - self.decay) * activity_nm
        hidden = self.decay * hidden + (1 - self.decay) * activity

        if nm is None:
            return hidden
        return torch.cat([hidden, nm], dim = 2)

class RNNLowRankLayer(nn.Module):
    """Unroll an ``RNNCellLowRank`` over the leading (time) axis of the input.

    This behaves very similarly to nn.RNN() but returns the NM state appended
    to the hidden state along the last dimension of the tensor.
    """

    def __init__(self, N_nm, input_size, hidden_size, rank0, rank1, nonlinearity, decay = 0.9, bias = False, keepW0 = False):
        super().__init__()
        self.rnncell = RNNCellLowRank(
            N_nm, input_size, hidden_size, rank0, rank1,
            nonlinearity = nonlinearity, decay = decay, bias = bias, keepW0 = keepW0,
        )
        self.N_nm = N_nm

    def forward(self, input, initH):
        """Run the cell once per slice of ``input`` along dim 0.

        Returns:
            (stacked per-step states with their leading axis squeezed,
             state after the last step).
        """
        state = initH      # initial state has dimension [1, batch, n_rnn]
        trajectory = []
        for step_input in input.unbind(0):     # each slice has dimension [batch, n_input]
            state = self.rnncell(step_input, state)
            trajectory.append(state.squeeze(0))   # the cell state itself is the output
        return torch.stack(trajectory), state



""" Custom nmRNN implementation using JIT """

import torch
from torch import nn, jit
import math


class nmRNNCell_LowRank_base(jit.ScriptModule):
    """Parameter container for a low-rank neuromodulated recurrent cell.

    Unlike ``RNNCell_LowRank_base`` the neuromodulator (NM) parameters are
    always allocated, whatever the value of ``N_NM``.

    Parameters created:
        pages        -- (2, N_NM, rank1, hidden_size)
        page0        -- (2, rank0, hidden_size)
        weight_ih    -- (hidden_size, input_size)
        weight_h2nm  -- (N_NM, hidden_size)
        weight_nm2nm -- (N_NM, N_NM)
        bias         -- (hidden_size,) when ``bias`` is truthy, otherwise
                        registered as ``None``.

    ``nonlinearity`` and ``keepW0`` are stored as given and are not used by
    this class itself.
    """

    def __init__(self, N_NM, input_size, hidden_size,  rank0, rank1, nonlinearity, bias, keepW0 = False):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.nonlinearity = nonlinearity
        self.N_nm = N_NM
        self.keepW0 = keepW0
        self.g = 10  # gain used to derive the slope handed to kaiming_uniform_ for the pages

        # Low-rank factors (registration order defines parameters()/state_dict() order).
        self.pages = nn.Parameter(torch.empty(2, N_NM, rank1, hidden_size))
        self.page0 = nn.Parameter(torch.empty(2, rank0, hidden_size))

        # Input projection followed by the hidden -> NM and NM -> NM couplings.
        self.weight_ih = nn.Parameter(torch.empty(hidden_size, input_size))
        self.weight_h2nm = nn.Parameter(torch.empty(N_NM, hidden_size))
        self.weight_nm2nm = nn.Parameter(torch.empty(N_NM, N_NM))

        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(torch.empty(hidden_size))

        self.reset_parameters()

    def reset_parameters(self):
        """(Re)initialise every parameter in place.

        The call order fixes how the random stream is consumed; keep it as is
        so seeded runs stay reproducible.
        """
        nn.init.kaiming_uniform_(self.weight_ih, a=math.sqrt(5))

        page_slope = self.g / math.sqrt(self.hidden_size)
        for factor in (self.pages, self.page0):
            nn.init.kaiming_uniform_(factor, a=page_slope)

        nn.init.sparse_(self.weight_h2nm, 0.1)
        nn.init.zeros_(self.weight_nm2nm)

        if self.bias is None:
            return
        # Bias is drawn uniformly in +-1/sqrt(fan_in) of the input weights.
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight_ih)
        limit = 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.bias, -limit, limit)


            
class nmRNNCellLowRank(nmRNNCell_LowRank_base):  # Euler integration of rate-neuron network dynamics
    """Single-step low-rank neuromodulated rate cell.

    The combined state ``hiddenCombined`` is indexed with three axes and holds
    the hidden units followed by the ``N_nm`` neuromodulator (NM) units on the
    last axis. One call performs one leaky update::

        state <- decay * state + (1 - decay) * nonlinearity(drive)

    Args:
        N_nm: number of neuromodulator units.
        input_size, hidden_size, rank0, rank1, bias: forwarded to the base class.
        nonlinearity: callable applied to the summed drive.
        decay: leak factor of the update above.
        keepW0: forwarded to the base class, which stores it.
    """

    def __init__(self, N_nm, input_size, hidden_size, rank0, rank1, nonlinearity = None, decay = 0, bias = True, keepW0 = True):
        # keepW0 used to be silently dropped here; forward it so the stored flag
        # matches what the caller asked for.
        super().__init__(N_nm, input_size, hidden_size, rank0, rank1, nonlinearity, bias, keepW0)
        self.decay = decay    #  torch.exp( - dt/tau )
        self.N_nm = N_nm

    def forward(self, input, hiddenCombined):
        """Advance the combined state by one step.

        Args:
            input: input for this step; it is right-multiplied by ``weight_ih.t()``.
            hiddenCombined: rank-3 state whose last axis is ``[hidden | nm]``.

        Returns:
            The updated state in the same ``[hidden | nm]`` layout
            (only the hidden part when ``N_nm == 0``).
        """
        # Rebuild the recurrent matrices from their low-rank factors. The per-NM
        # pages are formed directly rather than via every cross-NM outer product
        # followed by a diagonal extraction.
        self.weight_hh = tch.einsum('nri,nrj->ijn', self.pages[0], self.pages[1])
        self.weight0_hh = tch.einsum('ri,rj->ij', self.page0[0], self.page0[1])

        # Split the combined state into hidden units and neuromodulators.
        if self.N_nm > 0:
            hidden = hiddenCombined[:, :, :-self.N_nm]
            nm = hiddenCombined[:, :, -self.N_nm:]
        else:
            hidden = hiddenCombined
            nm = None

        drive = input @ self.weight_ih.t()
        if nm is not None:
            # hidden and nm are rank 3, so the same subscripts are needed with and
            # without a bias (the bias branch previously used rank-2 subscripts
            # and raised). The leading axis is summed out.
            drive = drive + torch.einsum('tbj,ijk,tbk->bi', hidden, self.weight_hh, nm)
        drive = drive + hidden @ self.weight0_hh.t()
        if self.bias is not None:
            drive = drive + self.bias
        activity = self.nonlinearity(drive)

        if nm is not None:
            # NMs are driven by the hidden state *before* this step's update.
            activity_nm = self.nonlinearity(hidden @ self.weight_h2nm.t() + nm @ self.weight_nm2nm.t())
            nm = self.decay * nm + (1 - self.decay) * activity_nm
        hidden = self.decay * hidden + (1 - self.decay) * activity

        # Without NMs there is nothing to append (torch.cat would fail on None).
        if nm is None:
            return hidden
        return torch.cat([hidden, nm], dim = 2)

class nmRNNLowRankLayer(nn.Module):
    """Unroll an ``nmRNNCellLowRank`` over the leading (time) axis of the input.

    This behaves very similarly to nn.RNN() but returns the NM state appended
    to the hidden state along the last dimension of the tensor.
    """

    def __init__(self, N_nm, input_size, hidden_size, rank0, rank1, nonlinearity, decay = 0.9, bias = False, keepW0 = False):
        super().__init__()
        self.rnncell = nmRNNCellLowRank(
            N_nm, input_size, hidden_size, rank0, rank1,
            nonlinearity = nonlinearity, decay = decay, bias = bias, keepW0 = keepW0,
        )
        self.N_nm = N_nm

    def forward(self, input, initH):
        """Run the cell once per slice of ``input`` along dim 0.

        Returns:
            (stacked per-step states with their leading axis squeezed,
             state after the last step).
        """
        state = initH      # initial state has dimension [1, batch, n_rnn]
        trajectory = []
        for step_input in input.unbind(0):     # each slice has dimension [batch, n_input]
            state = self.rnncell(step_input, state)
            trajectory.append(state.squeeze(0))   # the cell state itself is the output
        return torch.stack(trajectory), state

class nmRNNCell_base(jit.ScriptModule):
    """Parameter container for a full-rank neuromodulated recurrent cell.

    Parameters created:
        weight_ih    -- (hidden_size, input_size)
        weight_hh    -- (hidden_size, hidden_size, N_NM), one recurrent page per NM
        weight_h2nm  -- (N_NM, hidden_size)
        weight_nm2nm -- (N_NM, N_NM)
        weight0_hh   -- (hidden_size, hidden_size); trainable and randomly
                        initialised when ``keepW0`` is truthy, otherwise
                        frozen (``requires_grad=False``) at zero.
        bias         -- (hidden_size,) when ``bias`` is truthy, otherwise
                        registered as ``None``.

    ``nonlinearity`` is stored as given and is not used by this class itself.
    """

    def __init__(self, N_NM, input_size, hidden_size, nonlinearity, bias, keepW0 = False):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.nonlinearity = nonlinearity
        self.N_nm = N_NM
        self.keepW0 = keepW0
        self.g = 10  # gain used to derive the slope handed to kaiming_uniform_ for weight_hh

        # Registration order defines parameters()/state_dict() order.
        self.weight_ih = nn.Parameter(torch.empty(hidden_size, input_size))
        self.weight_hh = nn.Parameter(torch.empty(hidden_size, hidden_size, N_NM))
        self.weight_h2nm = nn.Parameter(torch.empty(N_NM, hidden_size))
        self.weight_nm2nm = nn.Parameter(torch.empty(N_NM, N_NM))
        # The unmodulated matrix is always a parameter; it only receives
        # gradients when keepW0 is set.
        self.weight0_hh = nn.Parameter(torch.empty(hidden_size, hidden_size), requires_grad = bool(keepW0))

        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(torch.empty(hidden_size))

        self.reset_parameters()

    def reset_parameters(self):
        """(Re)initialise every parameter in place.

        The call order fixes how the random stream is consumed; keep it as is
        so seeded runs stay reproducible.
        """
        nn.init.kaiming_uniform_(self.weight_ih, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.weight_hh, a=self.g / math.sqrt(self.hidden_size))
        nn.init.sparse_(self.weight_h2nm, 0.1)
        nn.init.zeros_(self.weight_nm2nm)

        if not self.keepW0:
            nn.init.zeros_(self.weight0_hh)
        else:
            nn.init.kaiming_uniform_(self.weight0_hh, a=math.sqrt(5))

        if self.bias is None:
            return
        # Bias is drawn uniformly in +-1/sqrt(fan_in) of the input weights.
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight_ih)
        limit = 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.bias, -limit, limit)


            
class nmRNNCell(nmRNNCell_base):  # Euler integration of rate-neuron network dynamics
    """Single-step full-rank neuromodulated rate cell.

    The combined state ``hiddenCombined`` is indexed with three axes and holds
    the hidden units followed by the ``N_nm`` neuromodulator (NM) units on the
    last axis. One call performs one leaky update::

        state <- decay * state + (1 - decay) * nonlinearity(drive)

    Args:
        N_nm: number of neuromodulator units.
        input_size, hidden_size, bias: forwarded to the base class.
        nonlinearity: callable applied to the summed drive.
        decay: leak factor of the update above.
        keepW0: forwarded to the base class; when truthy the unmodulated
            recurrent matrix ``weight0_hh`` is trainable and randomly
            initialised, otherwise it is frozen at zero.
    """

    def __init__(self, N_nm, input_size, hidden_size, nonlinearity = None, decay = 0, bias = True, keepW0 = True):
        # keepW0 was previously not forwarded, so weight0_hh was always frozen at
        # zero no matter what the caller requested.
        super().__init__(N_nm, input_size, hidden_size, nonlinearity, bias, keepW0)
        self.decay = decay    #  torch.exp( - dt/tau )
        self.N_nm = N_nm

    def forward(self, input, hiddenCombined):
        """Advance the combined state by one step.

        Args:
            input: input for this step; it is right-multiplied by ``weight_ih.t()``.
            hiddenCombined: rank-3 state whose last axis is ``[hidden | nm]``.

        Returns:
            The updated state in the same ``[hidden | nm]`` layout
            (only the hidden part when ``N_nm == 0``).
        """
        # Split the combined state into hidden units and neuromodulators.
        if self.N_nm > 0:
            hidden = hiddenCombined[:, :, :-self.N_nm]
            nm = hiddenCombined[:, :, -self.N_nm:]
        else:
            hidden = hiddenCombined
            nm = None

        drive = input @ self.weight_ih.t()
        if nm is not None:
            # hidden and nm are rank 3, so the same subscripts are needed with and
            # without a bias (the bias branch previously used rank-2 subscripts
            # and raised). The leading axis is summed out.
            drive = drive + torch.einsum('tbj,ijk,tbk->bi', hidden, self.weight_hh, nm)
        drive = drive + hidden @ self.weight0_hh.t()
        if self.bias is not None:
            drive = drive + self.bias
        activity = self.nonlinearity(drive)

        if nm is not None:
            # NMs are driven by the hidden state *before* this step's update.
            activity_nm = self.nonlinearity(hidden @ self.weight_h2nm.t() + nm @ self.weight_nm2nm.t())
            nm = self.decay * nm + (1 - self.decay) * activity_nm
        hidden = self.decay * hidden + (1 - self.decay) * activity

        # Without NMs there is nothing to append (torch.cat would fail on None).
        if nm is None:
            return hidden
        return torch.cat([hidden, nm], dim = 2)

class nmRNNLayer(nn.Module):
    """Unroll an ``nmRNNCell`` over the leading (time) axis of the input.

    This behaves very similarly to nn.RNN() but returns the NM state appended
    to the hidden state along the last dimension of the tensor.
    """

    def __init__(self, N_nm, input_size, hidden_size, nonlinearity, decay = 0.9, bias = False, keepW0 = False):
        super().__init__()
        self.rnncell = nmRNNCell(
            N_nm, input_size, hidden_size,
            nonlinearity = nonlinearity, decay = decay, bias = bias, keepW0 = keepW0,
        )
        self.N_nm = N_nm

    def forward(self, input, initH):
        """Run the cell once per slice of ``input`` along dim 0.

        Returns:
            (stacked per-step states with their leading axis squeezed,
             state after the last step).
        """
        state = initH      # initial state has dimension [1, batch, n_rnn]
        trajectory = []
        for step_input in input.unbind(0):     # each slice has dimension [batch, n_input]
            state = self.rnncell(step_input, state)
            trajectory.append(state.squeeze(0))   # the cell state itself is the output
        return torch.stack(trajectory), state


class NM_RNN(nn.Module):
    '''
    MSB: 04/10/2024. This is a dirty and slow implementation of the nmRNN for development which appears to work well enough to get started

    Each neuromodulator (NM) owns one linear "page" acting on the concatenated
    [input, hidden] vector; the recurrent drive is the NM-weighted sum of the pages.
    Hidden and NM states follow an Euler step of size ``dt``:

        x <- x + dt * (-x + nonlinearity(drive))

    Args:
        N_NM: number of neuromodulators (and pages).
        N_cells: number of hidden units.
        N_input: input dimension.
        N_out: stored but not used by this class.
        nbatch: batch size used to shape the registered ``hidden``/``NM`` buffers.
        linearize: when truthy, ``nonlin`` is the identity.

    ToDos:
    *see how much can be gained with JIT compilation
    *Is there much to be gained with a torch.einsum implementation on the tensor parameterization?

    '''
    def __init__(self, N_NM, N_cells, N_input=1, N_out=1, nbatch=1, linearize = False):
        super(NM_RNN, self).__init__()
        self.N_NM = N_NM
        self.N_cells = N_cells
        self.N_input = N_input
        self.N_out = N_out
        self.nbatch = nbatch
        # nonlin() needs this flag; it used to read an undefined global instead.
        self.linearize = linearize

        self.pages = nn.ModuleList([nn.Linear(self.N_input + self.N_cells, self.N_cells, bias=False) for _ in range(self.N_NM)])
        self.NM2NM = nn.Linear(self.N_NM, self.N_NM, bias=False)
        self.r2NM = nn.Linear(self.N_cells, self.N_NM, bias=False)
        self.register_buffer('hidden', torch.rand(nbatch, self.N_cells))
        self.register_buffer('NM', torch.rand(nbatch, self.N_NM))
        self.register_buffer('dt', torch.tensor(0.1))

    def nonlin(self, x):
        """Hidden-unit nonlinearity: identity when ``linearize`` is set, else relu(tanh(x))."""
        if self.linearize:
            return x
        return F.relu(torch.tanh(x))

    def NM_nonlin(self, x):
        """Neuromodulator nonlinearity: relu(tanh(x))."""
        return F.relu(torch.tanh(x))

    def forward(self, inputs, hidden_init):
        """Simulate the network over the leading (time) axis of ``inputs``.

        Args:
            inputs: tensor indexed as [time, batch, N_input].
            hidden_init: initial hidden state; it is detached and its leading
                axis is squeezed, giving [batch, N_cells].

        Returns:
            (output, hidden): ``output`` is [time, batch, N_cells] and starts with
            the initial hidden state, so only the first ``time - 1`` inputs are
            consumed; ``hidden`` is the last hidden state, [batch, N_cells].
        """
        inputs = inputs.permute(1, 0, 2)   # -> [batch, time, N_input]
        nbatch, ntime, _ = inputs.shape

        hidden = hidden_init.detach().squeeze(0)
        # NMs start at one; create them next to the inputs so GPU runs work.
        nm = torch.ones(nbatch, self.N_NM, device=inputs.device, dtype=inputs.dtype)
        outputs = [hidden]

        for t in range(ntime - 1):
            page_input = torch.cat((inputs[:, t, :], hidden), dim=1)
            # NM-weighted sum of the per-NM pages.
            weighted = [nm[:, i].unsqueeze(1) * page(page_input) for i, page in enumerate(self.pages)]
            z = torch.stack(weighted, dim=2).sum(dim=2)
            hidden = hidden + self.dt * (-hidden + self.nonlin(z))
            # The NM update sees the freshly updated hidden state.
            nm = nm + self.dt * (-nm + self.NM_nonlin(self.r2NM(hidden) + self.NM2NM(nm)))
            outputs.append(hidden)

        # Done once after the loop: stacking per step was quadratic in time and
        # left `output`/`hidden` undefined for single-step inputs.
        self.NM = nm.detach()
        output = torch.stack(outputs, dim=1).permute(1, 0, 2)
        return output, hidden



class NMsLowRankRNN(nn.Module):
    '''
    Low-rank Neuromodulated RNNs module

    A cortical ("ctx") rate network whose recurrent connectivity is the sum of an
    unmodulated low-rank matrix and one low-rank "page" per neuromodulator (NM),
    each page being scaled by the scalar output of its own NM network.

    Args:
        N_cells: dict with keys 'n_NMs' (number of NMs), 'N_NMs' (cells per NM
            network) and 'N_ctx' (cells in the ctx network).
        ranks: dict with keys 'rank_ctx_rec', 'rank_NMs_rec' and 'rank_NMs_ctx'.
        nbatch: batch size expected by ``forward``.
        device: device the module is moved to at construction.
    '''
    def __init__(self, N_cells, ranks, nbatch, device='cpu'):
        super(NMsLowRankRNN, self).__init__()
        #device
        self.device = device

        #number of units
        self.n_NMs = N_cells['n_NMs'] #number of NMs
        self.N_NMs = N_cells['N_NMs'] #number of cells in the NM networks
        self.N_ctx = N_cells['N_ctx'] # number of cells on the CTX network
        self.nbatch = nbatch

        #rank low-rank projections
        self.rank_ctx_rec = ranks['rank_ctx_rec']
        self.rank_NMs_rec = ranks['rank_NMs_rec']
        self.rank_NMs_ctx = ranks['rank_NMs_ctx']

        # NOTE: the order of the randn calls below fixes the seeded initialisation.
        amp_init = 0.1
        amp_ctx = amp_init/np.sqrt(self.N_ctx)
        amp_nm = amp_init/np.sqrt(self.N_NMs)

        #recurrent Ctx: neuromodulated pages, then the NOT neuromodulated part
        self.pages = tch.nn.Parameter(
            amp_ctx * tch.randn(2, self.n_NMs, self.rank_ctx_rec, self.N_ctx), requires_grad=True)
        self.patterns_ctx0 = tch.nn.Parameter(
            amp_ctx * tch.randn(2, self.rank_ctx_rec, self.N_ctx), requires_grad=True)

        #recurrent NMs
        self.patterns_NMs_rec = tch.nn.Parameter(
            amp_nm * tch.randn(2, self.n_NMs, self.rank_NMs_rec, self.N_NMs), requires_grad=True)

        #ctx to NMs projections
        patterns_pre = amp_ctx * tch.randn(self.n_NMs, self.rank_NMs_ctx, self.N_ctx)
        patterns_post = amp_nm * tch.randn(self.n_NMs, self.rank_NMs_ctx, self.N_NMs)
        self.patterns_ctx = tch.nn.Parameter(patterns_pre, requires_grad=True)
        self.patterns_NMs = tch.nn.Parameter(patterns_post, requires_grad=True)

        #NM to Ctx
        self.J_ctx_NMs = tch.nn.Parameter(amp_nm * tch.randn(self.n_NMs, self.N_NMs), requires_grad=True)

        #biases
        self.bias_ctx = tch.nn.Parameter(0.1 * tch.randn(self.N_ctx), requires_grad=True)
        self.bias_NMs = tch.nn.Parameter(0.1 * tch.randn(self.n_NMs, self.N_NMs), requires_grad=True)

        #decoder / encoder
        self.decoder_ctx = tch.nn.Parameter(0.1 * tch.randn(self.N_ctx), requires_grad=True)
        self.encoder_ctx = tch.nn.Parameter(0.1 * tch.randn(self.N_ctx), requires_grad=True)
        self.dt = 0.1
        self.to(self.device)
        # Build the connectivity after the move so it lives on the same device
        # as the parameters it is derived from.
        self._refresh_couplings()

    def _refresh_couplings(self):
        """Rebuild the J_* connectivity tensors from the current low-rank factors.

        They are derived (non-leaf) tensors, so they must be recomputed for every
        forward pass: otherwise they keep stale values after an optimizer step and
        a second backward() fails on the already-freed graph.
        """
        # One (i, j) matrix per NM; equals the diagonal of the cross-NM outer product.
        self.J_pages = tch.einsum('nri,nrj->nij', self.pages[0], self.pages[1])
        self.J_ctx0 = tch.einsum('ri,rj->ij', self.patterns_ctx0[0], self.patterns_ctx0[1])
        self.J_NMs = tch.einsum('nri,nrj->nij', self.patterns_NMs_rec[0], self.patterns_NMs_rec[1])
        self.J_NMs_ctx = tch.einsum('nri,nrj->nij', self.patterns_NMs, self.patterns_ctx)

    def nonlin(self, x):
        """Rate nonlinearity: tanh."""
        return tch.tanh(x)

    def NM_nonlin(self, x):
        '''Neuromodulator signal
           should be between 0 and 1'''
        return 0.5 * (tch.tanh(x) + 1.)

    def update_ctx(self, h_ctx, h_nm, input_signal_t):
        """ One Euler step of the ctx network.

            b: batch index
            n,k: neuromodulator index
            i,j: neuron index

            Uses the J_* tensors built by the most recent forward()/construction.
            Returns (new h_ctx, per-NM scalar signal, decoded ctx output), the
            last two being computed from the state *before* the step. """
        r_ctx = self.nonlin(h_ctx)
        r_NMs = self.nonlin(h_nm)
        # Each NM network is read out with its own row of J_ctx_NMs.
        proj_NMs = tch.einsum('nj,bnj->bn', self.J_ctx_NMs, r_NMs)
        mean_NMs = self.NM_nonlin(proj_NMs)
        output_ctx = tch.einsum('j,bj->b', self.decoder_ctx, r_ctx)
        #currents
        current_signal = tch.einsum('i,b->bi', self.encoder_ctx, input_signal_t)
        current_input = self.bias_ctx
        current_rec = tch.einsum('ij,bj->bi', self.J_ctx0, r_ctx)
        sum_pages = tch.einsum('nij,bn->bij', self.J_pages, mean_NMs)
        current_pages = tch.einsum('bij,bj->bi', sum_pages, r_ctx)
        total_current = current_pages + current_rec + current_input + current_signal
        #euler update
        # NOTE(review): there is no -h_ctx leak term here; confirm this is intended.
        h_ctx = h_ctx + self.dt * total_current
        return h_ctx, mean_NMs, output_ctx

    def update_NMs(self, h_ctx, h_nm):
        '''Euler update for the nNMs
            neuromodulator RNNs: this
            updates nNMs networks at
            the same time'''
        current_input = self.bias_NMs
        # Each NM network only receives recurrent input from itself.
        current_rec = tch.einsum('nij,bnj->bni', self.J_NMs, self.nonlin(h_nm))
        current_ctx = tch.einsum('nij,bj->bni', self.J_NMs_ctx, self.nonlin(h_ctx))
        total_current = current_ctx + current_rec + current_input
        h_nm = h_nm + self.dt * total_current
        return h_nm

    def forward(self, input_signal):
        """Simulate the coupled ctx/NM networks.

        Args:
            input_signal: tensor of shape (nbatch, ntime); ``nbatch`` must equal
                the value given at construction.

        Returns:
            dict with 'rates_ctx' (nbatch, ntime, N_ctx), 'output_NMs'
            (nbatch, ntime, n_NMs) and 'output_ctx' (nbatch, ntime).
        """
        nbatch, ntime = input_signal.shape
        assert nbatch==self.nbatch

        # Connectivity must reflect the current parameters and carry a fresh graph.
        self._refresh_couplings()
        # Follow the parameters rather than the construction-time device string,
        # which is stale if the module was moved afterwards.
        device = self.bias_ctx.device

        #tensors to save variables
        rates_ctx = tch.zeros(self.nbatch, ntime, self.N_ctx).to(device)
        output_NMs = tch.zeros(self.nbatch, ntime, self.n_NMs).to(device)
        output_ctx = tch.zeros(self.nbatch, ntime).to(device)

        #initial conditions
        h_ctx_n = 0.01 * tch.randn(self.nbatch, self.N_ctx).to(device)
        h_nm_n = 0.01 * tch.randn(self.nbatch, self.n_NMs, self.N_NMs).to(device)
        for i in range(ntime):
            # Both updates read the state at step n before either is overwritten.
            h_ctx_np1, mean_NMs, out_ctx = self.update_ctx(h_ctx_n, h_nm_n, input_signal[:,i])
            h_nm_np1 = self.update_NMs(h_ctx_n, h_nm_n)
            h_ctx_n = h_ctx_np1
            h_nm_n = h_nm_np1
            rates_ctx[:, i,:] = self.nonlin(h_ctx_n)
            output_NMs[:, i, :] = mean_NMs
            output_ctx[:, i] = out_ctx
        results = {'rates_ctx' : rates_ctx,
                   'output_NMs' :  output_NMs,
                   'output_ctx' :  output_ctx}
        return results